// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//


#pragma once

#include <pollux/vector/base_vector.h>
#include <pollux/vector/complex_vector.h>
#include <pollux/vector/decoded_vector.h>
#include <pollux/vector/flat_vector.h>

namespace kumo::pollux::row {
    /// Generic accessors for scalar values of the given TypeKind. In this
    /// primary template the serialized representation is the same type as the
    /// in-memory native type, and values are passed around by value.
    template<TypeKind Kind>
    struct ScalarTraits {
        using InMemoryType = typename TypeTraits<Kind>::NativeType;
        using SerializedType = InMemoryType;

        /// Returns the value stored at index 'i' of the flat vector 'v'.
        static SerializedType get(const FlatVector<InMemoryType> &v, vector_size_t i) {
            SerializedType value = v.value_at(i);
            return value;
        }

        /// Returns the value at index 'i' of the decoded vector 'v', read as
        /// InMemoryType.
        static SerializedType get(const DecodedVector &v, vector_size_t i) {
            SerializedType value = v.value_at<InMemoryType>(i);
            return value;
        }

        /// Stores 'val' at index 'i' of 'flat_vector'.
        static void set(FlatVector<InMemoryType> *flat_vector, vector_size_t i, SerializedType val) {
            flat_vector->set(i, val);
        }
    };

    template<>
    struct ScalarTraits<TypeKind::TIMESTAMP> {
        using InMemoryType = Timestamp;
        using SerializedType = int64_t;

        static int64_t get(const FlatVector<Timestamp> &v, vector_size_t i) {
            return v.value_at(i).toMicros();
        }

        static int64_t get(const DecodedVector &v, vector_size_t i) {
            return v.value_at<Timestamp>(i).toMicros();
        }

        static void
        set(FlatVector<Timestamp> *flat_vector, vector_size_t i, SerializedType val) {
            flat_vector->set(i, Timestamp::fromMicros(val));
        }
    };

    // We deliberately do not specify the SerializedType to ensure it isn't used.
    // We don't use the `value_at` methods because they make copies of the
    // StringView. This is required in the generic interface because it's not
    // possible to return a reference to a bool.
#define SCALAR_TRAIT(PolluxKind)                                             \
  template <>                                                               \
  struct ScalarTraits<TypeKind::PolluxKind> {                                \
    using InMemoryType = StringView;                                        \
                                                                            \
    static const StringView& get(                                           \
        const FlatVector<StringView>& v,                                    \
        vector_size_t i) {                                                  \
      return v.rawValues()[i];                                              \
    }                                                                       \
                                                                            \
    static const StringView& get(const DecodedVector& v, vector_size_t i) { \
      return v.data<StringView>()[v.index(i)];                              \
    }                                                                       \
                                                                            \
    static void set(                                                        \
        FlatVector<StringView>* flat_vector,                                 \
        vector_size_t i,                                                    \
        const StringView& val) {                                            \
      flat_vector->set(i, val);                                              \
    }                                                                       \
                                                                            \
    static void setNoCopy(                                                  \
        FlatVector<StringView>* flat_vector,                                 \
        vector_size_t i,                                                    \
        const StringView& val) {                                            \
      flat_vector->setNoCopy(i, val);                                        \
    }                                                                       \
  };

    SCALAR_TRAIT(VARBINARY)

    SCALAR_TRAIT(VARCHAR)

#undef SCALAR_TRAIT

    /**
     * Supports Apache Spark UnsafeRow format. Memory management should be handled
     * by the caller, UnsafeRow does not have the notion of underlying buffer size.
     * The size function returns the number of bytes written and can be used to
     * check for buffer overflows.
     *
     * As laid out by the two-argument constructor, the buffer starts with the
     * null set (sized by getNullLength()), followed by one kFieldWidthBytes-wide
     * slot per element, followed by the variable-length data region.
     */
    class UnsafeRow {
    public:
        /**
         * UnsafeRow field width in bytes.
         */
        static const size_t kFieldWidthBytes = 8;

        /**
         * UnsafeRow class constructor. The null set is placed at the start of the
         * buffer and the fixed-length slots immediately after it.
         * @param buffer pre-allocated buffer
         * @param elementCapacity number of elements in the row
         */
        UnsafeRow(char *buffer, size_t elementCapacity)
            : buffer_(buffer), elementCapacity_(elementCapacity) {
            nullSet_ = buffer;
            size_t nullLength = getNullLength(elementCapacity);
            fixedLengthData_ = nullSet_ + nullLength;
            // Variable-length data starts right after the fixed-length slots.
            variableLengthOffset_ = nullLength + elementCapacity_ * kFieldWidthBytes;
        }

        /**
         * UnsafeRow constructor for recursive sub-rows in the UnsafeRow, used when
         * writing or reading a complex type variable-length element.
         * @param buffer start of the memory backing this (sub-)row
         * @param nullSet location of the nulls
         * @param fixedLengthData location of the fixedLengthData
         * @param elementCapacity number of elements in the row
         * @param elementCount number of elements already written
         */
        UnsafeRow(
            char *buffer,
            char *nullSet,
            char *fixedLengthData,
            size_t elementCapacity,
            size_t elementCount = 0)
            : buffer_(buffer),
              nullSet_(nullSet),
              fixedLengthData_(fixedLengthData),
              elementCapacity_(elementCapacity),
              elementCount_(elementCount) {
            // Variable-length data starts right after the fixed-length slots;
            // the offset is measured from 'buffer'.
            variableLengthOffset_ =
                    fixedLengthData - buffer + elementCapacity * kFieldWidthBytes;
        }

        /**
         * @return the backing buffer.
         */
        char *buffer() const {
            return buffer_;
        }

        /**
         * @return the number of written elements.
         */
        size_t elementCount() const {
            return elementCount_;
        }

        /**
         * @return the element capacity.
         */
        size_t elementCapacity() const {
            return elementCapacity_;
        }

        /**
         * Increment the element count by count.  The row format is write-once
         * only so we do not need to support decrement.
         * @param count number of elements to add to the element count
         */
        void incrementElementCount(size_t count) {
            elementCount_ += count;
        }

        /**
         * @return pointer to the null set.
         */
        char *nullSet() const {
            return nullSet_;
        }

        /**
         * @return the size of the row in bytes, i.e. the offset from the start of
         * the buffer to the first unwritten variable-length byte.
         */
        size_t size() const {
            return variableLengthOffset_;
        }

        /**
         * @return the size of the metadata, i.e. the number of bytes between the
         * start of the buffer and the fixed length data. This is usually equal to
         * the size of the null set (it is exactly that for rows created with the
         * two-argument constructor).
         */
        size_t metadataSize() const {
            // Everything that precedes the fixed-length slots is metadata.
            // Measuring up to nullSet_ instead would yield 0 whenever the null
            // set sits at the start of the buffer.
            return fixedLengthData_ - buffer_;
        }

        /**
         * @return the location where fixed length data is written.
         */
        char *fixedLengthDataLocation() const {
            return fixedLengthData_;
        }

        /**
         * Set element at the given position to null.
         * @param pos element index
         */
        void setNullAt(size_t pos) {
            bits::setBit(nullSet_, pos);
        }

        /**
         * Set element at the given position to not null.
         * @param pos element index
         */
        void setNotNullAt(size_t pos) {
            bits::clearBit(nullSet_, pos);
        }

        /**
         * @param pos element index
         * @return true if the element is null, false otherwise
         */
        bool is_null_at(size_t pos) const {
            return bits::isBitSet(nullSet_, pos);
        }

        /**
         * Reads the data at a given index.
         * @param pos element index
         * @param type the element type
         * @return a string_view over the data; an empty string_view if the
         * element is null
         */
        const std::string_view readDataAt(size_t pos, const TypePtr &type) const {
            size_t cppSizeInBytes = type->isFixedWidth() ? type->cppSizeInBytes() : 0;
            return readDataAt(pos, type->isFixedWidth(), cppSizeInBytes);
        }

        /**
         * Reads the data at a given index.
         * @param pos element index
         * @param isFixedWidth whether the element is fixed width
         * @param fixedDataWidth 0 for variable-length data, native type width for
         * fixedLength data
         * @return a string_view of the data; an empty string_view if the element
         * is null
         */
        const std::string_view
        readDataAt(size_t pos, bool isFixedWidth, size_t fixedDataWidth = 0) const {
            if (is_null_at(pos)) {
                return std::string_view();
            } else if (isFixedWidth) {
                return readFixedLengthDataAt(pos, fixedDataWidth);
            }
            return readVariableLengthDataAt(pos);
        }

        /**
         * Aligns the variable length offset to the field width and returns the
         * location to serialize into.
         * @param pos element index
         * @param isFixedWidth whether the element is fixed width
         * @return the location of the fixed size data if isFixedWidth is true,
         * return the first unwritten word for variable length data otherwise.
         */
        char *getSerializationLocation(size_t pos, bool isFixedWidth) {
            setVariableLengthOffset(alignToFieldWidth(variableLengthOffset_));
            return isFixedWidth
                       ? reinterpret_cast<char *>(&reinterpret_cast<uint64_t *>(
                           fixedLengthData_)[pos])
                       : buffer_ + variableLengthOffset_;
        }

        /**
         * Write a variable length data offset pointer at the given index and
         * advance the variable length offset past the data (aligned to the field
         * width). The slot stores the current offset in its upper 32 bits and
         * the size in its lower 32 bits.
         * @param pos element index
         * @param size size in bytes of the variable length data
         */
        void writeOffsetPointer(size_t pos, size_t size) {
            POLLUX_CHECK_LE(variableLengthOffset_, UINT32_MAX);
            POLLUX_CHECK_LE(size, UINT32_MAX);
            // Widen before shifting so the shift is well defined even where
            // size_t is narrower than 64 bits.
            const uint64_t dataPointer =
                    (static_cast<uint64_t>(variableLengthOffset_) << 32) | size;

            // write the data pointer
            reinterpret_cast<uint64_t *>(fixedLengthData_)[pos] = dataPointer;
            setVariableLengthOffset(alignToFieldWidth(variableLengthOffset_ + size));
        }

        /**
         * @param size size in bytes
         * @return size aligned to field width.
         */
        static size_t alignToFieldWidth(size_t size) {
            return bits::roundUp(size, kFieldWidthBytes);
        }

        /**
         * Returns the length of the null set in bytes.
         * @param elementCount number of elements the null set has to cover
         * @return the null set length in bytes
         */
        static const size_t getNullLength(size_t elementCount) {
            return bits::nwords(elementCount) * kFieldWidthBytes;
        }

        /**
         * If the element is variable length, write the offset and size. For all
         * element types, set null if the element is null. Always increments the
         * element count by one.
         * @param pos element index
         * @param dataSize the serialized data size, std::nullopt for a null
         * element
         * @param isFixedWidth whether the element is fixed width
         */
        void writeOffsetAndNullAt(
            size_t pos,
            std::optional<size_t> dataSize,
            bool isFixedWidth) {
            if (!dataSize.has_value()) {
                setNullAt(pos);
            } else {
                setNotNullAt(pos);
                if (!isFixedWidth) {
                    writeOffsetPointer(pos, dataSize.value());
                }
            }

            incrementElementCount(1);
        }

    private:
        /**
         * Pre-allocated memory for the row.
         */
        char *buffer_;

        /**
         * Pointer to the start of the null indicators.
         */
        char *nullSet_;

        /**
         * Pointer to the start of fixed length data.
         */
        char *fixedLengthData_;

        /**
         * Offset to the start of unwritten variable length data region.
         */
        size_t variableLengthOffset_;

        /**
         * Capacity for the number of columns in the UnsafeRow.
         */
        size_t elementCapacity_;

        /**
         * Number of elements written to the Row.
         */
        size_t elementCount_ = 0;

        /**
         * Set variableLengthOffset_ to the given offset. Make sure the new value is
         * not smaller than the previous value, otherwise we might accidentally
         * overwrite data.
         * @param offset the new offset
         */
        void setVariableLengthOffset(size_t offset) {
            POLLUX_CHECK_GE(offset, variableLengthOffset_);
            variableLengthOffset_ = offset;
        }

        /**
         * Reads the data field as a fixed length data.
         * @param pos element index
         * @param width The element width in bytes, at most 8
         * @return a string_view of length width
         */
        const std::string_view readFixedLengthDataAt(size_t pos, size_t width) const {
            POLLUX_CHECK_LE(width, 8);
            uint64_t *dataPointer = &reinterpret_cast<uint64_t *>(fixedLengthData_)[pos];
            return std::string_view(reinterpret_cast<char *>(dataPointer), width);
        }

        /**
         * Reads the data field as a variable length data type.
         * @param pos element index
         * @return a string_view over the variable length data.
         */
        const std::string_view readVariableLengthDataAt(size_t pos) const {
            // The slot holds (offset << 32 | size), exactly as stored by
            // writeOffsetPointer. Decoding it arithmetically keeps reader and
            // writer in agreement independent of the host byte order.
            const uint64_t dataPointer =
                    reinterpret_cast<const uint64_t *>(fixedLengthData_)[pos];
            const auto size = static_cast<uint32_t>(dataPointer);
            const auto offset = static_cast<uint32_t>(dataPointer >> 32);

            return std::string_view(buffer_ + offset, size);
        }
    };

    template<TypeKind kind>
    size_t serializedSizeInBytes() {
        return sizeof(typename ScalarTraits<kind>::SerializedType);
    }

    /// VARCHAR specialization. ScalarTraits<TypeKind::VARCHAR> declares no
    /// SerializedType, so the generic definition cannot be used for this kind;
    /// this specialization only invokes POLLUX_UNREACHABLE() and is not expected
    /// to be called.
    template<>
    MELON_ALWAYS_INLINE size_t serializedSizeInBytes<TypeKind::VARCHAR>() {
        POLLUX_UNREACHABLE();
    }

    /// VARBINARY specialization. ScalarTraits<TypeKind::VARBINARY> declares no
    /// SerializedType, so the generic definition cannot be used for this kind;
    /// this specialization only invokes POLLUX_UNREACHABLE() and is not expected
    /// to be called.
    template<>
    MELON_ALWAYS_INLINE size_t serializedSizeInBytes<TypeKind::VARBINARY>() {
        POLLUX_UNREACHABLE();
    }

    /// Returns the number of bytes needed to serialize a value of the given
    /// fixed-width type, or 0 when 'type' reports is_unKnown(). Throws if 'type'
    /// is not fixed-width.
    MELON_ALWAYS_INLINE size_t serializedSizeInBytes(const TypePtr &type) {
        if (!type->is_unKnown()) {
            // Dispatch on the runtime kind to the matching template
            // instantiation.
            return POLLUX_DYNAMIC_SCALAR_TYPE_DISPATCH(
                        serializedSizeInBytes, type->kind());
        }
        return 0;
    }
} // namespace kumo::pollux::row
